{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VciV_QaYuTVd"
      },
      "source": [
        "# Train cross-lingual supervised text classifiers using multilingual sentence embeddings for elite criticism detection\n",
        "\n",
        "*author:* Hauke Licht\n",
        "\n",
        "In this notebook, I train cross-lingual supervised text classifiers using multilingual sentence embeddings for detecting elite criticism (anti-elite messages) in poltical parties' tweets.\n",
        "\n",
        "I rely on two pre-trained multilingual sentence embedding models (cf. [Licht 2022](https://osf.io/384wr/)): (a) a knowledge-distilled XLM-R sentence tranformer ([Reimers and Gurevych, 2022](https://arxiv.org/abs/2004.09813)) and (b) a Language Agnostic Sentence Embedding Representations (LASER) encoder ([Artxte and Schwenk, 2019]()).\n",
        "\n",
        "I use these models to obtain tweet embeddings which I use then as features to train low-complexity classifiers (a L2-regularized linear model (using SGD) and a Multi-Layer Perceptron, respectively) for classifying tweets posted by political parties according to whether or not they contain elite-critical statements.\n",
        "\n",
        "The labeled dataset I use records 5.3K+ tweets that have been sampled from tweets posted by political parties from 20 Western countries between 2008 and early 2021.\n",
        "The annotations come from 6 crowd coders per tweet that I have aggregated into tweet-level labels using a Dawid and Skene ([1979](https://doi.org/10.2307/2346806)) annotation model (cf. [Paun et al. 2018](https://aclanthology.org/Q18-1040.pdf))."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NVKKfzcPLA8T"
      },
      "source": [
        "\n",
        "# Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EQgctmDm6srd"
      },
      "source": [
        "## Set data paths"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "sB2uq4Jls4X4"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "base_path = os.path.join('..', '..')\n",
        "data_path = os.path.join(base_path, 'data')\n",
        "input_path = os.path.join(data_path, 'intermediate', 'training')\n",
        "res_dir = os.path.join(data_path, 'output', 'classifier_results')\n",
        "os.makedirs(res_dir, exist_ok = True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pejti7uxLECQ"
      },
      "source": [
        "## Load required packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "YqjecWFs4iBN"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/Users/hlicht/miniforge3/envs/laser/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
            "  from .autonotebook import tqdm as notebook_tqdm\n"
          ]
        }
      ],
      "source": [
        "# for I/O\n",
        "import os\n",
        "import re\n",
        "import json\n",
        "\n",
        "# misc\n",
        "from tqdm.notebook import tqdm\n",
        "import random\n",
        "\n",
        "# for data wrangling\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "import pycld2 as cld2\n",
        "\n",
        "# for sentence emebdding\n",
        "import torch\n",
        "from sentence_transformers import SentenceTransformer\n",
        "from laserembeddings import Laser\n",
        "\n",
        "# for train/test data preparation\n",
        "from sklearn.model_selection import GridSearchCV\n",
        "\n",
        "# Classifiers\n",
        "\n",
        "# for the L2-regularized Perceptron\n",
        "from sklearn.linear_model import SGDClassifier\n",
        "from sklearn.neural_network import MLPClassifier\n",
        "\n",
        "# for evaluation\n",
        "from sklearn.metrics import classification_report"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X9x5Mt_T7564"
      },
      "source": [
        "## Check CUDA availability"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zzux8UaP75QP",
        "outputId": "e3d55792-edd5-46d9-b29e-5714b71aba46"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "cpu\n"
          ]
        }
      ],
      "source": [
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "print(device) # this should print 'cuda'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XYhoyVxiTKOk"
      },
      "source": [
        "\n",
        "## Set global configuration"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "NqwGZYMY63kD"
      },
      "outputs": [],
      "source": [
        "# set the seed\n",
        "SEED = 1234\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XWqZ7LGMFeHV"
      },
      "source": [
        "# Data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eQWLAKg4aSeG"
      },
      "source": [
        "## Description"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BqBm2yyeYta4"
      },
      "source": [
        "The dataset we'll load has the following columns:\n",
        "\n",
        "- `item_id` (str): Unique ID of tweet (has been constructed by concatenating ISO-3-character country code of the party posting the tweet, `user_id`, and `status_id`)\n",
        "- `user_id` (int): the ID of the account that has posted the tweet\n",
        "- `status_id` (int): the ID of the tweet\n",
        "- `labeling` (str): the label class a tweet has been assigned to (i.e., its label)\n",
        "- `text` (str): The tweet's text (in its original language)\n",
        "- `text_en` (str): The tweet's machine-translated text (into English)\n",
        "- `test_` (bool): Boolean flag indicating tweets that should in the test (not the training) data split\n",
        "\n",
        "Note that user and status IDs are integers because they can be very long.\n",
        "Hence, I'll read them as int64 types to ensure they are not corrputed.\n",
        "(Alternatively, you could just treat them as strings.)\n",
        "\n",
        "To this end, I create a dictionary mapping column names to the desired data types that I'll pass when reading the CSV file:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "Wrj0Mxe2am0W"
      },
      "outputs": [],
      "source": [
        "col_types = {\n",
        "  'item_id': str,\n",
        "  'user_id': 'Int64',\n",
        "  'status_id': 'Int64',\n",
        "  'labeling': str,\n",
        "  'text': str,\n",
        "  'test_': bool\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rP65yYzRYLsb"
      },
      "source": [
        "## Download\n",
        "\n",
        "Download and read the labeled tweets dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "DsmH1rS_XKgr"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>item_id</th>\n",
              "      <th>user_id</th>\n",
              "      <th>status_id</th>\n",
              "      <th>labeling</th>\n",
              "      <th>text</th>\n",
              "      <th>text_en</th>\n",
              "      <th>test_</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>IRL_19530527_19326417276</td>\n",
              "      <td>19530527</td>\n",
              "      <td>19326417276</td>\n",
              "      <td>yes-general</td>\n",
              "      <td>The Government has become the single biggest o...</td>\n",
              "      <td>The Government has become the single biggest o...</td>\n",
              "      <td>True</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>IRL_22628924_120756369459642370</td>\n",
              "      <td>22628924</td>\n",
              "      <td>120756369459642370</td>\n",
              "      <td>no</td>\n",
              "      <td>As president Martin McGuinness will use his in...</td>\n",
              "      <td>As president Martin McGuinness will use his in...</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>IRL_22628924_120756369459642370</td>\n",
              "      <td>22628924</td>\n",
              "      <td>120756369459642370</td>\n",
              "      <td>no</td>\n",
              "      <td>As president Martin McGuinness will use his in...</td>\n",
              "      <td>As president Martin McGuinness will use his in...</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>ESP_83784273_1054674229541650440</td>\n",
              "      <td>83784273</td>\n",
              "      <td>1054674229541650440</td>\n",
              "      <td>no</td>\n",
              "      <td>El presidente del EBB, @andoniortuzar; la pres...</td>\n",
              "      <td>The president of EBB, @andoniortuzar; the pres...</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>GRC_2808263274_1014889230466715653</td>\n",
              "      <td>2808263274</td>\n",
              "      <td>1014889230466715653</td>\n",
              "      <td>yes-specific</td>\n",
              "      <td>#BeLeventis Κατηγορείτε ως ακροδεξιούς, όσους ...</td>\n",
              "      <td>#BeLeventis Worship as far-right, those who do...</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "                              item_id     user_id            status_id  \\\n",
              "0            IRL_19530527_19326417276    19530527          19326417276   \n",
              "1     IRL_22628924_120756369459642370    22628924   120756369459642370   \n",
              "2     IRL_22628924_120756369459642370    22628924   120756369459642370   \n",
              "3    ESP_83784273_1054674229541650440    83784273  1054674229541650440   \n",
              "4  GRC_2808263274_1014889230466715653  2808263274  1014889230466715653   \n",
              "\n",
              "       labeling                                               text  \\\n",
              "0   yes-general  The Government has become the single biggest o...   \n",
              "1            no  As president Martin McGuinness will use his in...   \n",
              "2            no  As president Martin McGuinness will use his in...   \n",
              "3            no  El presidente del EBB, @andoniortuzar; la pres...   \n",
              "4  yes-specific  #BeLeventis Κατηγορείτε ως ακροδεξιούς, όσους ...   \n",
              "\n",
              "                                             text_en  test_  \n",
              "0  The Government has become the single biggest o...   True  \n",
              "1  As president Martin McGuinness will use his in...  False  \n",
              "2  As president Martin McGuinness will use his in...  False  \n",
              "3  The president of EBB, @andoniortuzar; the pres...  False  \n",
              "4  #BeLeventis Worship as far-right, those who do...  False  "
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "fp = os.path.join(input_path, 'training_data_pooled_samples.csv')\n",
        "dat = pd.read_csv(fp, sep = \",\", dtype = col_types)\n",
        "dat.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "btNyCm4RYfiq"
      },
      "source": [
        "Set unique IDs as index."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "MW-9z7OLYPpL"
      },
      "outputs": [],
      "source": [
        "dat.set_index('item_id', inplace = True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n9kI6uU8cXP6"
      },
      "source": [
        "## Create binary labels\n",
        "\n",
        "Let's have a look at the `labeling` values:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "s6gLkIe8cbi0",
        "outputId": "3fd1ed24-7706-4e47-8688-5a042b2ca2b0"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "labeling\n",
              "no              3344\n",
              "yes-general     1289\n",
              "yes-specific     768\n",
              "Name: count, dtype: int64"
            ]
          },
          "execution_count": 8,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dat.labeling.value_counts()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dw4127o6chWY"
      },
      "source": [
        "The labelings indicates whether a tweet contains\n",
        "\n",
        "1. **no** elite criticism,\n",
        "2. elite criticism directed at **the elite in general**, or\n",
        "3. criticism of **specific elites**.\n",
        "\n",
        "We argue that *the essence of anti-elite rhetoric* (as a political strategy) is generalized elite criticism.\n",
        "Hence, we are mainly interested in the distinction between 'general' elite criticism and all other statements.\n",
        "Accordingly, I create a **binary** label indicator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5Et8jtZGdPM7",
        "outputId": "0456454a-ff32-4b09-d084-bd6c5721aa2d"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "label_\n",
              "0    4112\n",
              "1    1289\n",
              "Name: count, dtype: int64"
            ]
          },
          "execution_count": 9,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dat['label_'] = dat.labeling == 'yes-general' # positive (negative) class label => True (False)\n",
        "dat['label_'] = dat['label_'].astype(int)  # positive (negative) class label => 1 (0)\n",
        "dat.label_.value_counts()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gK7wm27EkKgH"
      },
      "source": [
        "## Preprocess tweets\n",
        "\n",
        "The sentence embedding models used below have not been pre-trained on Twitter data.\n",
        "Hence, I minally pre-process tweets to make them more like regular text: I remove URLs, and I remove hashtag and handle symbols."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "XUvzHYssk2n-"
      },
      "outputs": [],
      "source": [
        "def preprocess_tweet_text(text, handle_regex = r'@[A-Za-z0-9_]{4,15}', hashtag_regex = '#[A-Za-z0-9_]{2,240}'):\n",
        "  # covnert spaces\n",
        "  text = re.sub(r'\\s+', u'\\x20', text)\n",
        "  new_text = []\n",
        "  for t in text.split(u'\\x20'):\n",
        "    if t.startswith('http'):\n",
        "      continue\n",
        "    if re.search(re.compile(handle_regex), t) and len(t) > 1:\n",
        "      t = re.sub(r'.?@', '', t)\n",
        "      t = t[0].upper() + t[1:]\n",
        "    if re.search(re.compile(hashtag_regex), t) and len(t) > 1:\n",
        "      t = re.sub(r'#', '', t)\n",
        "      t = t[0].upper() + t[1:]\n",
        "    new_text.append(t)\n",
        "  return u'\\x20'.join(new_text)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WuIAgBoThb_F",
        "outputId": "6e5c8fb6-c3a3-42b2-835f-502ca33b5e9d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Når @statoil bruker 2 mill. på 16 siders reklamekampanje som går rett inn i valgkampen, har det gått for langt. https://t.co/uZTRDNM8mF\n",
            "=> Når Statoil bruker 2 mill. på 16 siders reklamekampanje som går rett inn i valgkampen, har det gått for langt.\n",
            "📺 Sinn Féin spokesperson for Climate Action @BrianStanleyTD will be live on @RTE_PrimeTime this evening from 9:35pm #rtept #ClimateActionNow 🌍 https://t.co/8nBQFGzoEt\n",
            "=> 📺 Sinn Féin spokesperson for Climate Action BrianStanleyTD will be live on RTE_PrimeTime this evening from 9:35pm Rtept ClimateActionNow 🌍\n",
            "\"@EquipoClavijo dio el paso porque era lo mejor para Canarias, para el partido y para garantizar la gobernabilidad en Canarias. Aunque finalmente no fuera posible ese acuerdo, su gesto honra su trayectoria política y supone un enorme servicio al partido”, aseguró @BarraganJM\n",
            "=> EquipoClavijo dio el paso porque era lo mejor para Canarias, para el partido y para garantizar la gobernabilidad en Canarias. Aunque finalmente no fuera posible ese acuerdo, su gesto honra su trayectoria política y supone un enorme servicio al partido”, aseguró BarraganJM\n",
            "Hechos y no solo palabras.\n",
            "Para @davidtoledoniz “esta es otra metedura de pata de las Nuevas Generaciones del PP, motivada por las prisas por criticar lo que sea con tal de hacer daño, sin pensar en contrastar datos o dando información sesgada para desinformar”.\n",
            "#CCSeMueve https://t.co/EB7myEKtcH\n",
            "=> Hechos y no solo palabras. Para Davidtoledoniz “esta es otra metedura de pata de las Nuevas Generaciones del PP, motivada por las prisas por criticar lo que sea con tal de hacer daño, sin pensar en contrastar datos o dando información sesgada para desinformar”. CCSeMueve\n",
            "Editorial von SVP-Nationalrätin @verenaherzog: Einschränkungen bei den Sozialleistungen für Ausländer verpasst https://t.co/N5zPn6LcqI\n",
            "=> Editorial von SVP-Nationalrätin Verenaherzog: Einschränkungen bei den Sozialleistungen für Ausländer verpasst\n"
          ]
        }
      ],
      "source": [
        "# test\n",
        "for tweet in dat.text.sample(5).values:\n",
        "  print(tweet)\n",
        "  print('=>', preprocess_tweet_text(tweet))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UrxL1YIbmiKz"
      },
      "source": [
        "Preprocess all tweets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "1G81jyac_SVC"
      },
      "outputs": [],
      "source": [
        "dat['text_clean'] = dat.text.apply(preprocess_tweet_text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ps_XARQRBYFT"
      },
      "source": [
        "## Sentence-embed tweets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8fHFnpkz6pwa"
      },
      "source": [
        "### Using the knowledge-distilled XLM-R model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bI-ItfUQ8Br2"
      },
      "source": [
        "download and instantiate the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "nx4--rY_zBbR"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Downloading (…)31d34/.gitattributes: 100%|██████████| 345/345 [00:00<00:00, 85.6kB/s]\n",
            "Downloading (…)_Pooling/config.json: 100%|██████████| 190/190 [00:00<00:00, 321kB/s]\n",
            "Downloading (…)e4a1a31d34/README.md: 100%|██████████| 3.74k/3.74k [00:00<00:00, 2.63MB/s]\n",
            "Downloading (…)a1a31d34/config.json: 100%|██████████| 718/718 [00:00<00:00, 567kB/s]\n",
            "Downloading (…)ce_transformers.json: 100%|██████████| 122/122 [00:00<00:00, 66.5kB/s]\n",
            "Downloading pytorch_model.bin: 100%|██████████| 1.11G/1.11G [00:09<00:00, 117MB/s] \n",
            "Downloading (…)nce_bert_config.json: 100%|██████████| 53.0/53.0 [00:00<00:00, 99.1kB/s]\n",
            "Downloading (…)tencepiece.bpe.model: 100%|██████████| 5.07M/5.07M [00:00<00:00, 83.8MB/s]\n",
            "Downloading (…)cial_tokens_map.json: 100%|██████████| 150/150 [00:00<00:00, 249kB/s]\n",
            "Downloading (…)31d34/tokenizer.json: 100%|██████████| 9.10M/9.10M [00:00<00:00, 30.2MB/s]\n",
            "Downloading (…)okenizer_config.json: 100%|██████████| 550/550 [00:00<00:00, 445kB/s]\n",
            "Downloading (…)1a31d34/modules.json: 100%|██████████| 229/229 [00:00<00:00, 419kB/s]\n"
          ]
        }
      ],
      "source": [
        "model = SentenceTransformer('paraphrase-xlm-r-multilingual-v1', device = device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aom8-k7D8EyP"
      },
      "source": [
        "embed the tweets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 49,
          "referenced_widgets": [
            "4b0ccde04a5042c785af0838218bb36f",
            "6a6c974f75384925900b997becd087cc",
            "5d2d846a0deb4aea8d8b49e51906e2cf",
            "aed3ca39f06e46908fd03ba5b760cb7d",
            "2edaaa00a9a247afa0b28423a960b857",
            "67f6eb6530fc490e985fe849d16de6a6",
            "96d0dbca22784548a3866d82404d48f9",
            "6ab505c1e2974c66897b9102bbc1ac96",
            "be3ec937dba5467b9e79439ed508d2d4",
            "6b45da087f7e47cd9513acf1eaafe256",
            "0fc30a80e1d14e4f803d9a01532ee982"
          ]
        },
        "id": "ou8CJfEW0DDg",
        "outputId": "55c69c73-b7be-48f6-8bc6-79da3e5674d8"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Batches: 100%|██████████| 169/169 [53:02<00:00, 18.83s/it]   \n"
          ]
        }
      ],
      "source": [
        "xlmrs = model.encode(dat.text_clean.values, show_progress_bar = True, device = device)\n",
        "# note: if you are not using GPU (i.e. 'cuda' device), this will take some time (ca. 3 minutes/500 tweets)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jBUvfo1m0hcq",
        "outputId": "9048ac92-db27-4daf-b153-7d83a291019f"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "(5401, 768)"
            ]
          },
          "execution_count": 18,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "xlmrs.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RJdzCGWJ8Gst"
      },
      "source": [
        "convert to indexed data frame"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "5DLhx4A81-_6"
      },
      "outputs": [],
      "source": [
        "col_names = [f'e{i+1:04d}' for i in range(xlmrs.shape[1])]\n",
        "xlmrs_df = pd.DataFrame(xlmrs, columns = col_names, index = dat.index.values)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G24X5S-D7BzR"
      },
      "source": [
        "### Using LASER"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JgVVmZA38OkI"
      },
      "source": [
        "In contrast to the model above, LASER requires information about input texts' language.\n",
        "I use the CLD2 library to detect tweets' languages (since Twitter's classifictation is often unreliable)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "id": "6JVUAWGk7F0J"
      },
      "outputs": [],
      "source": [
        "def detect_lang(x):\n",
        "  try:\n",
        "    is_reliable, _, details = cld2.detect(x, isPlainText = True, bestEffort = True)\n",
        "  except:\n",
        "    return str(np.nan)\n",
        "  if not is_reliable:\n",
        "    return str(np.nan)\n",
        "  return details[0][1]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "id": "AS3_h39VpFye"
      },
      "outputs": [],
      "source": [
        "dat['lang_guess'] = dat.text_clean.apply(detect_lang)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A7_WGM3M8ixx"
      },
      "source": [
        "inspect the result\n",
        "\n",
        "*Note:* These are ISO-639-2-character codes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "KWjre3DCpdW4",
        "outputId": "8ae4258d-deab-42eb-9574-4b5c118656fb"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "lang_guess\n",
              "en     1597\n",
              "de      683\n",
              "nl      532\n",
              "fr      485\n",
              "sv      333\n",
              "da      267\n",
              "es      267\n",
              "fi      233\n",
              "it      231\n",
              "no      227\n",
              "el      178\n",
              "pt      171\n",
              "gl       57\n",
              "ca       56\n",
              "lb       41\n",
              "eu       21\n",
              "nn        7\n",
              "nan       6\n",
              "af        4\n",
              "co        3\n",
              "sw        1\n",
              "ia        1\n",
              "Name: count, dtype: int64"
            ]
          },
          "execution_count": 22,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dat['lang_guess'].value_counts()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mjf_C3AJ8wfq"
      },
      "source": [
        "There are six tweets whose language could not been detected automatically."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Ii3oME8cqNQn",
        "outputId": "56916cd1-4d9f-4e51-d8ac-5d7212ab1ec1"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "item_id\n",
              "CHE_172375021_21746577302           Succès suisse  Succès PDC: Communiqué du PDC ...\n",
              "CHE_172375021_299521011760656385    Suisse-UE : il nous faut un agenda clair: La q...\n",
              "CHE_172375021_71148201406775297     Les enfants et les jeunes doivent être mieux p...\n",
              "FIN_393938517_195003065122824193    Kasapanokset kärkeen: Lähestyviä MM-jääkieko...\n",
              "AUT_26750370_42224401147502592      #spoe Kampagne Weniger Salz ist g'sünder: Sa...\n",
              "CHE_172375021_281808997617655809    Europa und wir  und wie weiter: Für die EU wi...\n",
              "Name: text, dtype: object"
            ]
          },
          "execution_count": 23,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dat[dat.lang_guess == 'nan'].text"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AhM5uGXD83au"
      },
      "source": [
        "I label these manually"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "id": "ma-IIvSKqavv"
      },
      "outputs": [],
      "source": [
        "dat.loc['CHE_172375021_21746577302', 'lang_guess'] = 'fr'\n",
        "dat.loc['CHE_172375021_299521011760656385', 'lang_guess'] = 'fr'\n",
        "dat.loc['CHE_172375021_71148201406775297', 'lang_guess'] = 'fr'\n",
        "dat.loc['FIN_393938517_195003065122824193', 'lang_guess'] = 'fi'\n",
        "dat.loc['AUT_26750370_42224401147502592', 'lang_guess'] = 'de'\n",
        "dat.loc['CHE_172375021_281808997617655809', 'lang_guess'] = 'de'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7zXw-8pf86Rb"
      },
      "source": [
        "I can now instantiate the model and embed the tweets:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
            "To disable this warning, you can either:\n",
            "\t- Avoid using `tokenizers` before the fork if possible\n",
            "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading models into /Users/hlicht/miniforge3/envs/laser/lib/python3.8/site-packages/laserembeddings/data\n",
            "\n",
            "✅   Downloaded https://dl.fbaipublicfiles.com/laser/models/93langs.fcodes    \n",
            "✅   Downloaded https://dl.fbaipublicfiles.com/laser/models/93langs.fvocab    \n",
            "⏳   Downloading https://dl.fbaipublicfiles.com/laser/models/bilstm.93langs.2018-12-26.pt..."
          ]
        }
      ],
      "source": [
        "!python3 -m laserembeddings download-models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 28,
      "metadata": {
        "id": "HEkCyC087Dev"
      },
      "outputs": [],
      "source": [
        "model = Laser()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {
        "id": "u4GpyyX-rlWy"
      },
      "outputs": [],
      "source": [
        "lasers = model.embed_sentences(dat.text_clean.values, dat.lang_guess.values)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sR7g-0g18-Hd"
      },
      "source": [
        "Convert the embeddings to an indexed data frame:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "id": "mgNCjInBsCcl"
      },
      "outputs": [],
      "source": [
        "col_names = [f'e{i+1:04d}' for i in range(lasers.shape[1])]\n",
        "lasers_df = pd.DataFrame(lasers, columns = col_names, index = dat.index.values)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j7VUkxPeqzme"
      },
      "source": [
        "# Preparation for training and evaluation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t4zQ6tGSb7iN"
      },
      "source": [
        "I next split the dataset into the training and test partitions.\n",
        "To do so, I use the `test_` indicator column that comes with the dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 31,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6fkdzC2S0xnS",
        "outputId": "fd97b75c-a820-4a96-bcf8-5a563d9d851c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "No. train samples: 4342; pos. label proportion: 0.236\n",
            "No. test samples:  1059; pos. label proportion: 0.250\n"
          ]
        }
      ],
      "source": [
        "train_dat = dat[~dat.test_]\n",
        "train_xlmrs = xlmrs_df[~dat.test_]\n",
        "train_lasers = lasers_df[~dat.test_]\n",
        "print(f'No. train samples: {len(train_dat)}; pos. label proportion: {train_dat.label_.values.mean():.3f}')\n",
        "\n",
        "test_dat = dat[dat.test_]\n",
        "test_xlmrs = xlmrs_df[dat.test_]\n",
        "test_lasers = lasers_df[dat.test_]\n",
        "print(f'No. test samples:  {len(test_dat)}; pos. label proportion: {test_dat.label_.values.mean():.3f}')"
      ]
    },
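    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of why a single boolean mask can split several data frames at once: the mask is a Series indexed by tweet ID, so it selects the matching rows in any frame that shares that index. The toy frames, IDs and values below are made up for illustration, and the sketch assumes (as in the cell above) that the embedding frames carry the same index as `dat`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "# toy data (hypothetical IDs and values, not from the real dataset)\n",
        "ids = ['a', 'b', 'c', 'd']\n",
        "toy_dat = pd.DataFrame({'label_': [0, 1, 0, 1], 'test_': [False, False, True, False]}, index=ids)\n",
        "toy_emb = pd.DataFrame(np.arange(8.0).reshape(4, 2), columns=['e0001', 'e0002'], index=ids)\n",
        "\n",
        "# boolean mask indexed by tweet ID\n",
        "is_test = toy_dat['test_']\n",
        "\n",
        "# the same mask subsets both frames because their indexes are identical\n",
        "toy_train_emb = toy_emb[~is_test]\n",
        "toy_test_emb = toy_emb[is_test]\n",
        "\n",
        "print(list(toy_train_emb.index), list(toy_test_emb.index))\n",
        "```\n",
        "\n",
        "If the indexes did not match, pandas would raise an error or misalign rows, so it is worth checking index equality before masking."
      ]
    },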
    {
      "cell_type": "code",
      "execution_count": 32,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2Nun4MhB8qS3",
        "outputId": "fcee8f10-7535-4fcd-a165-68051cce3012"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "label_\n",
              "0    3318\n",
              "1    1024\n",
              "Name: count, dtype: int64"
            ]
          },
          "execution_count": 32,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "train_dat.label_.value_counts()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VbKiq70P1dOP"
      },
      "source": [
        "Next, I load the JSON file that records which tweets are in which CV folds:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 35,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9O0I5IHb65Dr",
        "outputId": "cd0b9c9b-e157-403e-c317-7229385936f5"
      },
      "outputs": [],
      "source": [
        "with open(os.path.join(input_path, 'cv_ids.json'), 'r') as f:\n",
        "  cv_folds = json.load(f)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WHr9kUm99Z2R"
      },
      "source": [
        "Let's verify that the training data split can be subsetted by these IDs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 37,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Ybg7Pv48qiuw",
        "outputId": "eba9f4fc-52a2-40b9-8f42-56daa6733afc"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "fold 0: # train = 3473, # test = 869\n",
            "fold 1: # train = 3473, # test = 869\n",
            "fold 2: # train = 3474, # test = 868\n",
            "fold 3: # train = 3474, # test = 868\n",
            "fold 4: # train = 3474, # test = 868\n"
          ]
        }
      ],
      "source": [
        "# verify CV splits\n",
        "for i, fold in cv_folds.items():\n",
        "  print(f\"fold {i}: # train = {len(fold['train'])}, # test = {len(fold['val'])}\")\n",
        "  # the next two lines would throw an error if something were wrong with the indices\n",
        "  train_dat.loc[ fold['train'] ]\n",
        "  train_dat.loc[ fold['val'] ]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5W3dKOVU9esH"
      },
      "source": [
        "I convert these IDs into positional row *indexes* for each fold:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 38,
      "metadata": {
        "id": "4jxJvj7P3ue9"
      },
      "outputs": [],
      "source": [
        "cv_idxs = list()\n",
        "for fold in cv_folds.values():\n",
        "  train_ids = set(fold['train']) # set lookup avoids scanning the ID list for every row\n",
        "  train_idxs, val_idxs = list(), list()\n",
        "  for idx, tweet_id in enumerate(train_dat.index):\n",
        "    if tweet_id in train_ids:\n",
        "      train_idxs.append(idx)\n",
        "    else:\n",
        "      val_idxs.append(idx)\n",
        "  cv_idxs.append( ( np.asarray(train_idxs), np.asarray(val_idxs) ) )"
      ]
    },
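    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an alternative sketch of the same ID-to-position conversion, pandas' `Index.get_indexer` can look up the positions directly. The toy index and fold below are hypothetical. Unlike the loop above, this variant looks up the validation IDs explicitly instead of treating every non-training row as validation, and it assumes the index has no duplicate IDs.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "# toy training index and one toy fold (hypothetical IDs)\n",
        "toy_index = pd.Index(['t1', 't2', 't3', 't4', 't5'])\n",
        "toy_fold = {'train': ['t1', 't3', 't4'], 'val': ['t2', 't5']}\n",
        "\n",
        "# get_indexer returns the integer position of each ID in the index (-1 if missing)\n",
        "toy_train_pos = toy_index.get_indexer(toy_fold['train'])\n",
        "toy_val_pos = toy_index.get_indexer(toy_fold['val'])\n",
        "\n",
        "# fail early if a fold references an ID that is not in the index\n",
        "assert (toy_train_pos >= 0).all() and (toy_val_pos >= 0).all()\n",
        "\n",
        "print(toy_train_pos, toy_val_pos)\n",
        "```\n",
        "\n",
        "The resulting pair of integer arrays has the same form as one element of `cv_idxs`."
      ]
    },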
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7vH7kBbC9imK"
      },
      "source": [
        "Finally, I define dictionaries for collecting the best hyperparameters, cross-validation results, and classifier evaluation results:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 39,
      "metadata": {
        "id": "AczVzB1Z4_Sf"
      },
      "outputs": [],
      "source": [
        "best_params = dict()\n",
        "cv_res = dict()\n",
        "performances = dict()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g2LE-XghzN5E"
      },
      "source": [
        "# Train and evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 40,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "J5ux9UEGG4Ja",
        "outputId": "42962835-0056-4392-ae5e-2bc2f9a54d3b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 180 µs (started: 2023-10-19 20:10:04 +02:00)\n"
          ]
        }
      ],
      "source": [
        "# track time\n",
        "%load_ext autotime"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 41,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "using 10 CPUs\n",
            "time: 671 µs (started: 2023-10-19 20:10:07 +02:00)\n"
          ]
        }
      ],
      "source": [
        "# get the number of available CPUs\n",
        "import multiprocessing\n",
        "n_cpus = multiprocessing.cpu_count()\n",
        "print(f'using {n_cpus} CPUs')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-JogIdWRmcvb"
      },
      "source": [
        "## Train L2-regularized linear model (SGD with hinge loss)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 42,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QVqSfhsSO0Z8",
        "outputId": "991a4835-e186-45ed-8666-279ad798b7fd"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 920 µs (started: 2023-10-19 20:10:11 +02:00)\n"
          ]
        }
      ],
      "source": [
        "nws = 0.5**np.asarray(range(1,5)) # [0.5, 0.25, 0.125, 0.0625]\n",
        "class_weights = [{0: nw, 1: 1-nw} for nw in nws]\n",
        "\n",
        "# the larger alpha, the stronger the regularization (i.e., the more the coefficients are shrunk towards zero)\n",
        "tuning_params = dict(\n",
        "    alpha = [1e-3, 1e-4, 1e-5],\n",
        "    class_weight = class_weights\n",
        ")\n",
        "\n",
        "# define 'estimator'\n",
        "clf_per = SGDClassifier(\n",
        "  loss = 'hinge',\n",
        "  penalty = 'l2',\n",
        "  fit_intercept = False,\n",
        "  random_state = SEED,\n",
        "  n_jobs = n_cpus\n",
        ")\n",
        "\n",
        "# initialize grid searcher\n",
        "grid_search = GridSearchCV(\n",
        "  estimator = clf_per,\n",
        "  param_grid = tuning_params,\n",
        "  cv = cv_idxs,\n",
        "  scoring = ['precision', 'recall', 'f1'],\n",
        "  refit = 'f1',\n",
        "  verbose = 0\n",
        ")"
      ]
    },
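    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To illustrate how the pieces above fit together, here is a small sketch on synthetic data: predefined (train, validation) index pairs are passed as `cv`, several metrics are scored, and `refit = 'f1'` picks the setting with the highest mean validation F1. The data, the two hand-made splits and the `toy_*` names are assumptions for illustration only.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "from sklearn.datasets import make_classification\n",
        "from sklearn.linear_model import SGDClassifier\n",
        "from sklearn.model_selection import GridSearchCV\n",
        "\n",
        "# synthetic, imbalanced stand-in for the embeddings (not the real data)\n",
        "X, y = make_classification(n_samples=200, n_features=20, weights=[0.75, 0.25], random_state=0)\n",
        "\n",
        "# two hand-made (train, validation) splits given as positional index arrays\n",
        "pos = np.arange(len(y))\n",
        "even, odd = pos[pos % 2 == 0], pos[pos % 2 == 1]\n",
        "toy_cv = [(even, odd), (odd, even)]\n",
        "\n",
        "toy_search = GridSearchCV(\n",
        "    estimator=SGDClassifier(loss='hinge', penalty='l2', random_state=0),\n",
        "    param_grid={'alpha': [1e-3, 1e-4], 'class_weight': [{0: 0.5, 1: 0.5}, {0: 0.25, 1: 0.75}]},\n",
        "    cv=toy_cv,\n",
        "    scoring=['precision', 'recall', 'f1'],\n",
        "    refit='f1',  # refit on all rows with the setting that maximizes mean validation F1\n",
        ")\n",
        "toy_search.fit(X, y)\n",
        "\n",
        "# one entry per parameter combination, with score columns per metric and split\n",
        "toy_res = toy_search.cv_results_\n",
        "print(toy_search.best_params_, len(toy_res['params']))\n",
        "```\n",
        "\n",
        "With 2 values of `alpha` and 2 class-weight settings, the grid has 4 combinations, each fitted once per split."
      ]
    },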
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yNCwWfjgv1Wf"
      },
      "source": [
        "### with XLM-R embeddings\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 43,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Jmu7dKrKscW6",
        "outputId": "796c8b5b-c74c-49b7-ad3e-1ce3012b8a04"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 14.4 s (started: 2023-10-19 20:10:39 +02:00)\n"
          ]
        }
      ],
      "source": [
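        "from sklearn.base import clone # fit() returns the searcher itself, so I fit a clone that later grid_search.fit() calls cannot overwrite\n",
        "clf_per_xlmrs = clone(grid_search).fit(train_xlmrs.values, train_dat.label_.values)"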
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 44,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YM-hhswstw5n",
        "outputId": "7b0ed2ea-2ee0-4127-e1c3-5dbfab76f6de"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'alpha': 0.001, 'class_weight': {0: 0.25, 1: 0.75}}"
            ]
          },
          "execution_count": 44,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 3.06 ms (started: 2023-10-19 20:10:54 +02:00)\n"
          ]
        }
      ],
      "source": [
        "key = 'per_xlmrs'\n",
        "cv_res[key] = pd.DataFrame(clf_per_xlmrs.cv_results_)\n",
        "\n",
        "best_params[key] = clf_per_xlmrs.best_params_\n",
        "best_params[key]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "205Q8E0usyTd"
      },
      "source": [
        "#### Evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 45,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vI7ZBEojQwL8",
        "outputId": "5d85f0dd-cc22-4df2-9742-288bc4c0e32b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 1.07 ms (started: 2023-10-19 20:10:54 +02:00)\n"
          ]
        }
      ],
      "source": [
        "preds = clf_per_xlmrs.predict(test_xlmrs.values)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 46,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5beAgLRnRLDw",
        "outputId": "4d0a8c73-7edf-4582-8b0e-6dcd252c54a3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "              precision    recall  f1-score   support\n",
            "\n",
            "         neg       0.87      0.76      0.81       794\n",
            "         pos       0.48      0.66      0.56       265\n",
            "\n",
            "    accuracy                           0.74      1059\n",
            "   macro avg       0.68      0.71      0.69      1059\n",
            "weighted avg       0.77      0.74      0.75      1059\n",
            "\n",
            "time: 7.46 ms (started: 2023-10-19 20:11:44 +02:00)\n"
          ]
        }
      ],
      "source": [
        "print(classification_report(test_dat.label_.values, preds, labels = [0, 1], target_names = ['neg', 'pos']))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 47,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 224
        },
        "id": "kDDQwtp-vS0P",
        "outputId": "5d4f1b61-5890-483f-aa1b-3d49dec9ca5d"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>precision</th>\n",
              "      <th>recall</th>\n",
              "      <th>f1-score</th>\n",
              "      <th>support</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>neg</th>\n",
              "      <td>0.870690</td>\n",
              "      <td>0.763224</td>\n",
              "      <td>0.813423</td>\n",
              "      <td>794.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>pos</th>\n",
              "      <td>0.482094</td>\n",
              "      <td>0.660377</td>\n",
              "      <td>0.557325</td>\n",
              "      <td>265.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>accuracy</th>\n",
              "      <td>0.737488</td>\n",
              "      <td>0.737488</td>\n",
              "      <td>0.737488</td>\n",
              "      <td>0.737488</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>macro avg</th>\n",
              "      <td>0.676392</td>\n",
              "      <td>0.711801</td>\n",
              "      <td>0.685374</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>weighted avg</th>\n",
              "      <td>0.773449</td>\n",
              "      <td>0.737488</td>\n",
              "      <td>0.749338</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "              precision    recall  f1-score      support\n",
              "neg            0.870690  0.763224  0.813423   794.000000\n",
              "pos            0.482094  0.660377  0.557325   265.000000\n",
              "accuracy       0.737488  0.737488  0.737488     0.737488\n",
              "macro avg      0.676392  0.711801  0.685374  1059.000000\n",
              "weighted avg   0.773449  0.737488  0.749338  1059.000000"
            ]
          },
          "execution_count": 47,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 15.1 ms (started: 2023-10-19 20:11:50 +02:00)\n"
          ]
        }
      ],
      "source": [
        "# see https://stackoverflow.com/a/53780589\n",
        "res = classification_report(test_dat.label_, preds, labels = [0, 1], target_names = ['neg', 'pos'], output_dict = True)\n",
        "performances[key] = pd.DataFrame(res).transpose()\n",
        "performances[key] # note: accuracy = micro average in binary classification"
      ]
    },
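    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A short sketch of what the conversion above does, using made-up labels and predictions: `classification_report(..., output_dict = True)` returns a nested dictionary in which `accuracy` is a single number, so pandas repeats it across all four columns after transposing. The sketch also checks the note that accuracy equals the micro-averaged F1 score here.\n",
        "\n",
        "```python\n",
        "import pandas as pd\n",
        "from sklearn.metrics import classification_report, f1_score\n",
        "\n",
        "# toy labels and predictions (made up for illustration)\n",
        "y_true = [0, 0, 0, 1, 1, 0, 1, 0]\n",
        "y_pred = [0, 1, 0, 1, 0, 0, 1, 0]\n",
        "\n",
        "rep = classification_report(y_true, y_pred, labels=[0, 1], target_names=['neg', 'pos'], output_dict=True)\n",
        "\n",
        "# 'accuracy' is stored as one float, so it is broadcast to every column\n",
        "rep_df = pd.DataFrame(rep).transpose()\n",
        "\n",
        "# in single-label classification, accuracy equals the micro-averaged F1 score\n",
        "micro_f1 = f1_score(y_true, y_pred, average='micro')\n",
        "\n",
        "print(rep_df.round(3))\n",
        "print(micro_f1)\n",
        "```\n",
        "\n",
        "This is why the `support` entry in the `accuracy` row shows the accuracy value rather than a count."
      ]
    },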
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1pr3Apndv40E"
      },
      "source": [
        "### with LASER embeddings"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 48,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "S7UjmII7E8JC",
        "outputId": "13da8ef7-eb3e-4bd6-c973-63ed80a4a16c"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/Users/hlicht/miniforge3/envs/laser/lib/python3.8/site-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 8.74 s (started: 2023-10-19 20:11:59 +02:00)\n"
          ]
        }
      ],
      "source": [
        "clf_per_lasers = grid_search.fit(train_lasers.values, train_dat.label_.values)"
      ]
    },
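    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `UndefinedMetricWarning` above is raised when a fitted model predicts no positive cases in a validation fold, which leaves precision undefined. One possible way to handle this explicitly, sketched below, is to build scorers with `zero_division = 0`. The `toy_scoring` dictionary is my own illustration and not part of the analysis above.\n",
        "\n",
        "```python\n",
        "from sklearn.metrics import f1_score, make_scorer, precision_score, recall_score\n",
        "\n",
        "# scorers that return 0 without a warning when a metric is undefined\n",
        "toy_scoring = {\n",
        "    'precision': make_scorer(precision_score, zero_division=0),\n",
        "    'recall': make_scorer(recall_score, zero_division=0),\n",
        "    'f1': make_scorer(f1_score, zero_division=0),\n",
        "}\n",
        "\n",
        "# a degenerate prediction vector without positives: precision is undefined here\n",
        "y_true = [0, 1, 0, 1]\n",
        "y_pred = [0, 0, 0, 0]\n",
        "p = precision_score(y_true, y_pred, zero_division=0)\n",
        "print(p)\n",
        "```\n",
        "\n",
        "Such a dictionary could be passed as `scoring` to `GridSearchCV` together with `refit = 'f1'`. The scores stay the same as with the default behavior; only the warnings disappear."
      ]
    },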
    {
      "cell_type": "code",
      "execution_count": 49,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yrhQtFR_QT6_",
        "outputId": "63af0110-38a5-44dd-f6d0-a05a4453ee4a"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'alpha': 0.0001, 'class_weight': {0: 0.25, 1: 0.75}}"
            ]
          },
          "execution_count": 49,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 3.5 ms (started: 2023-10-19 20:12:12 +02:00)\n"
          ]
        }
      ],
      "source": [
        "key = 'per_lasers'\n",
        "cv_res[key] = pd.DataFrame(clf_per_lasers.cv_results_)\n",
        "\n",
        "best_params[key] = clf_per_lasers.best_params_\n",
        "best_params[key]\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BVi87-HWtJIe"
      },
      "source": [
        "#### Evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 50,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Hf06Vyf8tHTO",
        "outputId": "53b6af49-f1a5-4222-c9d5-3652a3f916cd"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 1.89 ms (started: 2023-10-19 20:12:19 +02:00)\n"
          ]
        }
      ],
      "source": [
        "preds = clf_per_lasers.predict(test_lasers.values)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 51,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xtVajDqEtHTP",
        "outputId": "6b2777fd-f942-4c65-e1d9-6e8789d6d2a3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "              precision    recall  f1-score   support\n",
            "\n",
            "         neg       0.86      0.70      0.77       794\n",
            "         pos       0.42      0.66      0.52       265\n",
            "\n",
            "    accuracy                           0.69      1059\n",
            "   macro avg       0.64      0.68      0.64      1059\n",
            "weighted avg       0.75      0.69      0.71      1059\n",
            "\n",
            "time: 6.93 ms (started: 2023-10-19 20:12:21 +02:00)\n"
          ]
        }
      ],
      "source": [
        "print(classification_report(test_dat.label_.values, preds, labels = [0, 1], target_names = ['neg', 'pos']))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 52,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 224
        },
        "id": "OpRqrj5wtHTP",
        "outputId": "903c8f50-28cc-4082-fdab-126b09bb2166"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>precision</th>\n",
              "      <th>recall</th>\n",
              "      <th>f1-score</th>\n",
              "      <th>support</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>neg</th>\n",
              "      <td>0.861371</td>\n",
              "      <td>0.696474</td>\n",
              "      <td>0.770195</td>\n",
              "      <td>794.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>pos</th>\n",
              "      <td>0.422062</td>\n",
              "      <td>0.664151</td>\n",
              "      <td>0.516129</td>\n",
              "      <td>265.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>accuracy</th>\n",
              "      <td>0.688385</td>\n",
              "      <td>0.688385</td>\n",
              "      <td>0.688385</td>\n",
              "      <td>0.688385</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>macro avg</th>\n",
              "      <td>0.641717</td>\n",
              "      <td>0.680312</td>\n",
              "      <td>0.643162</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>weighted avg</th>\n",
              "      <td>0.751440</td>\n",
              "      <td>0.688385</td>\n",
              "      <td>0.706619</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "              precision    recall  f1-score      support\n",
              "neg            0.861371  0.696474  0.770195   794.000000\n",
              "pos            0.422062  0.664151  0.516129   265.000000\n",
              "accuracy       0.688385  0.688385  0.688385     0.688385\n",
              "macro avg      0.641717  0.680312  0.643162  1059.000000\n",
              "weighted avg   0.751440  0.688385  0.706619  1059.000000"
            ]
          },
          "execution_count": 52,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 12 ms (started: 2023-10-19 20:12:29 +02:00)\n"
          ]
        }
      ],
      "source": [
        "res = classification_report(test_dat.label_, preds, labels = [0, 1], target_names = ['neg', 'pos'], output_dict = True)\n",
        "performances[key] = pd.DataFrame(res).transpose()\n",
        "performances[key]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fNo000B5nhFe"
      },
      "source": [
        "## Train MLP"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9kme_K7fzOq-"
      },
      "source": [
        "I make my life easy and just use the MLP classifier provided by the sklearn package.\n",
        "Implementing the MLP in keras would allow utilizing the GPU and result in faster training, but I'm patient and want to focus on the results, not the code."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 59,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oYhQi9Jc2IWs",
        "outputId": "aa1188aa-2efe-478e-af3b-9b45433a6607"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 736 µs (started: 2023-10-19 20:14:32 +02:00)\n"
          ]
        }
      ],
      "source": [
        "tuning_params = dict(\n",
        "  batch_size = [64, 128, 256],\n",
        "  learning_rate_init = [1e-2, 1e-3, 1e-4],\n",
        ")\n",
        "\n",
        "# define the 'estimator'\n",
        "clf_mlp = MLPClassifier(\n",
        "  activation = 'relu',\n",
        "  solver = 'adam',\n",
        "  max_iter = 100,\n",
        "  hidden_layer_sizes = 100,\n",
        "  learning_rate = 'adaptive', # note: sklearn only uses this schedule when solver = 'sgd'\n",
        "  early_stopping = True,\n",
        "  random_state = SEED\n",
        ")\n",
        "\n",
        "# initialize grid searcher\n",
        "grid_search = GridSearchCV(\n",
        "  estimator = clf_mlp,\n",
        "  param_grid = tuning_params,\n",
        "  cv = cv_idxs,\n",
        "  scoring = ['precision', 'recall', 'f1'],\n",
        "  refit = 'f1',\n",
        "  #n_jobs = n_cpus,\n",
        "  verbose = 0\n",
        ")"
      ]
    },
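    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the role of `early_stopping = True` concrete, here is a minimal sketch on synthetic data. With early stopping, sklearn holds out a share of the training rows (`validation_fraction`) and stops once the validation score has not improved for `n_iter_no_change` epochs. The data and the `toy_mlp` name are assumptions for illustration; the values shown for the two extra parameters are sklearn's defaults.\n",
        "\n",
        "```python\n",
        "from sklearn.datasets import make_classification\n",
        "from sklearn.neural_network import MLPClassifier\n",
        "\n",
        "# synthetic stand-in for the embeddings (not the real data)\n",
        "X, y = make_classification(n_samples=300, n_features=20, random_state=0)\n",
        "\n",
        "toy_mlp = MLPClassifier(\n",
        "    hidden_layer_sizes=(100,),\n",
        "    solver='adam',\n",
        "    max_iter=100,\n",
        "    early_stopping=True,      # hold out part of the training rows for validation\n",
        "    validation_fraction=0.1,  # share of held-out rows (sklearn default)\n",
        "    n_iter_no_change=10,      # patience in epochs (sklearn default)\n",
        "    random_state=0,\n",
        ")\n",
        "toy_mlp.fit(X, y)\n",
        "\n",
        "# number of epochs actually run and number of recorded validation scores\n",
        "print(toy_mlp.n_iter_, len(toy_mlp.validation_scores_))\n",
        "```\n",
        "\n",
        "Note that the held-out validation rows are drawn from within each cross-validation training fold, so they are separate from the validation folds used by the grid search."
      ]
    },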
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R3IlS5y52SrS"
      },
      "source": [
        "### with XLM-R embeddings"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 60,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mUWi9iqdX6jv",
        "outputId": "cfad7e20-e1cb-4bc9-bf49-13aea82e25e2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 42.4 s (started: 2023-10-19 20:14:34 +02:00)\n"
          ]
        }
      ],
      "source": [
        "%%capture\n",
        "clf_mlp_xlmrs = grid_search.fit(train_xlmrs.values, train_dat.label_.values)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IHvvNPx_oi61"
      },
      "source": [
        "#### Evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 61,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hAhYcXYYaq4y",
        "outputId": "e3fdfb20-1faf-46f3-809b-2a9074fb1ec6"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'batch_size': 128, 'learning_rate_init': 0.001}"
            ]
          },
          "execution_count": 61,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 3.04 ms (started: 2023-10-19 20:15:17 +02:00)\n"
          ]
        }
      ],
      "source": [
        "key = 'mlp_xlmrs'\n",
        "cv_res[key] = pd.DataFrame(clf_mlp_xlmrs.cv_results_)\n",
        "\n",
        "best_params[key] = clf_mlp_xlmrs.best_params_\n",
        "best_params[key]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 62,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "imS0eKmmat2P",
        "outputId": "08d7c275-7479-4a56-d25b-b0fd00318dfb"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 2.8 ms (started: 2023-10-19 20:15:17 +02:00)\n"
          ]
        }
      ],
      "source": [
        "preds = clf_mlp_xlmrs.predict(test_xlmrs.values)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 63,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7x0350UDauTX",
        "outputId": "1942ab4c-00f0-4fed-9c33-ac7cf34a8a7f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "              precision    recall  f1-score   support\n",
            "\n",
            "         neg       0.80      0.93      0.86       794\n",
            "         pos       0.60      0.32      0.42       265\n",
            "\n",
            "    accuracy                           0.78      1059\n",
            "   macro avg       0.70      0.62      0.64      1059\n",
            "weighted avg       0.75      0.78      0.75      1059\n",
            "\n",
            "time: 5.25 ms (started: 2023-10-19 20:15:17 +02:00)\n"
          ]
        }
      ],
      "source": [
        "print(classification_report(test_dat.label_.values, preds, labels = [0, 1], target_names = ['neg', 'pos']))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 64,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 224
        },
        "id": "V9wDSds52kQY",
        "outputId": "8d5a4d7c-d3a1-4647-c00d-1044a2aa6744"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>precision</th>\n",
              "      <th>recall</th>\n",
              "      <th>f1-score</th>\n",
              "      <th>support</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>neg</th>\n",
              "      <td>0.803261</td>\n",
              "      <td>0.930730</td>\n",
              "      <td>0.862310</td>\n",
              "      <td>794.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>pos</th>\n",
              "      <td>0.604317</td>\n",
              "      <td>0.316981</td>\n",
              "      <td>0.415842</td>\n",
              "      <td>265.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>accuracy</th>\n",
              "      <td>0.777148</td>\n",
              "      <td>0.777148</td>\n",
              "      <td>0.777148</td>\n",
              "      <td>0.777148</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>macro avg</th>\n",
              "      <td>0.703789</td>\n",
              "      <td>0.623856</td>\n",
              "      <td>0.639076</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>weighted avg</th>\n",
              "      <td>0.753478</td>\n",
              "      <td>0.777148</td>\n",
              "      <td>0.750588</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "              precision    recall  f1-score      support\n",
              "neg            0.803261  0.930730  0.862310   794.000000\n",
              "pos            0.604317  0.316981  0.415842   265.000000\n",
              "accuracy       0.777148  0.777148  0.777148     0.777148\n",
              "macro avg      0.703789  0.623856  0.639076  1059.000000\n",
              "weighted avg   0.753478  0.777148  0.750588  1059.000000"
            ]
          },
          "execution_count": 64,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 13.5 ms (started: 2023-10-19 20:15:37 +02:00)\n"
          ]
        }
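      ],
      "source": [
        "# get the classification report as a dict and store it as a data frame (one row per class or average)\n",
        "res = classification_report(test_dat.label_, preds, labels = [0, 1], target_names = ['neg', 'pos'], output_dict = True)\n",
        "performances[key] = pd.DataFrame(res).transpose()\n",
        "performances[key]"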
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U0UTy88N2vxd"
      },
      "source": [
        "### Using LASER embeddings"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 65,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NRNU_Wnb2vxd",
        "outputId": "939da566-c8a2-426d-aed8-db24172ba6d7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 39.1 s (started: 2023-10-19 20:15:54 +02:00)\n"
          ]
        }
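      ],
      "source": [
        "%%capture\n",
        "# NOTE: assuming `grid_search` is a scikit-learn search object (e.g., GridSearchCV), `fit()` returns the object itself,\n",
        "#  so `clf_mlp_lasers` is the same object as `grid_search`, and any name bound to an earlier fit of it now refers to this re-fitted model\n",
        "clf_mlp_lasers = grid_search.fit(train_lasers.values, train_dat.label_.values)"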
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 66,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "w1WQ5VzH2vxe",
        "outputId": "8ac638a3-aeca-481f-df25-f3a72a29587a"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'batch_size': 128, 'learning_rate_init': 0.001}"
            ]
          },
          "execution_count": 66,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 3.1 ms (started: 2023-10-19 20:16:33 +02:00)\n"
          ]
        }
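      ],
      "source": [
        "# store the cross-validation results and the best hyperparameters of the MLP trained on LASER embeddings\n",
        "key = 'mlp_lasers'\n",
        "cv_res[key] = pd.DataFrame(clf_mlp_lasers.cv_results_)\n",
        "\n",
        "best_params[key] = clf_mlp_lasers.best_params_\n",
        "best_params[key]"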
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3O7RJvtw2vxe"
      },
      "source": [
        "#### Evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 67,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QwkgnQPs2vxe",
        "outputId": "55d160ab-30e1-4f7a-95d7-dc50d3a3b4c0"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 3.45 ms (started: 2023-10-19 20:16:33 +02:00)\n"
          ]
        }
      ],
      "source": [
        "preds = clf_mlp_lasers.predict(test_lasers.values)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 68,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mvtovAmv2vxe",
        "outputId": "ab202692-3a95-4fe9-a8e6-a399f30dfb95"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "              precision    recall  f1-score   support\n",
            "\n",
            "         neg       0.80      0.91      0.85       794\n",
            "         pos       0.55      0.32      0.41       265\n",
            "\n",
            "    accuracy                           0.76      1059\n",
            "   macro avg       0.68      0.62      0.63      1059\n",
            "weighted avg       0.74      0.76      0.74      1059\n",
            "\n",
            "time: 5.12 ms (started: 2023-10-19 20:16:33 +02:00)\n"
          ]
        }
      ],
      "source": [
        "print(classification_report(test_dat.label_.values, preds, labels = [0, 1], target_names = ['neg', 'pos']))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 69,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 224
        },
        "id": "AHRzqs182vxe",
        "outputId": "b2f1bbe9-f418-43e5-babd-a75cfcd3840c"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>precision</th>\n",
              "      <th>recall</th>\n",
              "      <th>f1-score</th>\n",
              "      <th>support</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>neg</th>\n",
              "      <td>0.801105</td>\n",
              "      <td>0.913098</td>\n",
              "      <td>0.853443</td>\n",
              "      <td>794.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>pos</th>\n",
              "      <td>0.551948</td>\n",
              "      <td>0.320755</td>\n",
              "      <td>0.405728</td>\n",
              "      <td>265.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>accuracy</th>\n",
              "      <td>0.764873</td>\n",
              "      <td>0.764873</td>\n",
              "      <td>0.764873</td>\n",
              "      <td>0.764873</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>macro avg</th>\n",
              "      <td>0.676527</td>\n",
              "      <td>0.616926</td>\n",
              "      <td>0.629586</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>weighted avg</th>\n",
              "      <td>0.738757</td>\n",
              "      <td>0.764873</td>\n",
              "      <td>0.741409</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "              precision    recall  f1-score      support\n",
              "neg            0.801105  0.913098  0.853443   794.000000\n",
              "pos            0.551948  0.320755  0.405728   265.000000\n",
              "accuracy       0.764873  0.764873  0.764873     0.764873\n",
              "macro avg      0.676527  0.616926  0.629586  1059.000000\n",
              "weighted avg   0.738757  0.764873  0.741409  1059.000000"
            ]
          },
          "execution_count": 69,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 11.8 ms (started: 2023-10-19 20:17:06 +02:00)\n"
          ]
        }
      ],
      "source": [
        "res = classification_report(test_dat.label_, preds, labels = [0, 1], target_names = ['neg', 'pos'], output_dict = True)\n",
        "performances[key] = pd.DataFrame(res).transpose()\n",
        "performances[key]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KJ_SeWgIYyt_"
      },
      "source": [
        "## Save results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 70,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iPBe19XT8PNI",
        "outputId": "60f00ce7-5559-43a0-97a9-be9ad8714ccd"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 14.7 ms (started: 2023-10-19 20:17:20 +02:00)\n"
          ]
        }
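      ],
      "source": [
        "# write the cross-validation results of all models to a tab-separated file (model name and parameter set as index levels)\n",
        "fp = os.path.join(res_dir, 'mse_cv_results.tab')\n",
        "pd.concat(cv_res, names = ['model', 'param_set']).to_csv(fp, sep = '\\t')"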
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 71,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0sSELCmt9bvQ",
        "outputId": "4499515d-4764-4477-8898-d086dd0b27e6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 1.53 ms (started: 2023-10-19 20:17:22 +02:00)\n"
          ]
        }
      ],
      "source": [
        "# write the best hyperparameters of each model to a JSON file\n",
        "import json\n",
        "fp = os.path.join(res_dir, 'mse_best_params.json')\n",
        "with open(fp, 'w') as f:\n",
        "  json.dump(best_params, f)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 72,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RMlEh0yz82-f",
        "outputId": "25d23df5-7798-4bf6-e690-7ba330f1fb7d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 3.59 ms (started: 2023-10-19 20:17:24 +02:00)\n"
          ]
        }
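      ],
      "source": [
        "# write the test set performances of all models to a tab-separated file (model name and metric row as index levels)\n",
        "fp = os.path.join(res_dir, 'mse_test_results.tab')\n",
        "pd.concat(performances, names = ['model', 'what']).to_csv(fp, sep = '\\t')"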
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.18"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
